
Conversation

@ricardoV94
Member

Related to #1806 #1827

- Fix a bug when passing a simple Tensor shape to split_dims
- Change grad_undefined -> grad_disconnected for split_sizes in SplitOp (see #1827 for more context)

```python
):
    # All elements already have the right number of dimensions, so we
    # can just join them directly.
    return join(0, *x)
```

Member Author

This isn't equivalent to stack below?

Member
@jessegrabowski jessegrabowski Jan 9, 2026

No, because stack adds a dimension. This was causing a bug in split_dims where we explicitly ask for ndims=1, passing a sequence of 1d tensors, but then we get back a 2d tensor.
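
A minimal sketch of the difference being described, assuming the current pytensor `stack` and `join` API:

```python
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")

# stack introduces a new leading axis, so two 1d inputs become a 2d output
stacked = pt.stack([a, b])
print(stacked.ndim)  # 2

# join concatenates along an existing axis, so the result stays 1d
joined = pt.join(0, a, b)
print(joined.ndim)  # 1
```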

Member Author

Hmm, so my understanding is that this function is supposed to do what np.array(x) would do. I think the ndim is more of an assert: it should fail when the output of np.array (in our case the symbolic equivalent) would yield something different. So in that sense join is never valid, as it keeps the same number of dimensions.
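
For reference, np.array with a list of 1d arrays stacks rather than joins:

```python
import numpy as np

rows = [np.ones(3), np.zeros(3)]
print(np.array(rows).ndim)        # 2 -- a new leading axis is added
print(np.concatenate(rows).ndim)  # 1 -- join-like behaviour keeps ndim
```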

I want to revert and check if I'm missing something with the test that was failing.

Member

Sure. From my perspective the biggest issue is that as_tensor_variable(..., ndims=1) isn't idempotent -- sequential calls on the same input keep mutating the same graph. This is happening because of stack.
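
A hedged illustration of the non-idempotence being reported (behaviour before the revert, as I read this thread; the keyword in `as_tensor_variable`'s signature is `ndim`):

```python
import pytensor.tensor as pt

x = pt.vector("x")
y1 = pt.as_tensor_variable(x, ndim=1)
y2 = pt.as_tensor_variable(y1, ndim=1)

# Idempotence would mean y2 is y1 (or at least an identical graph);
# the report here is that each call wrapped its input in another node.
print(y1 is x, y2 is y1)
```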

Member Author
@ricardoV94 ricardoV94 Jan 9, 2026

That's odd, because if it's already a single tensor variable (and not a list with one in it) it shouldn't do anything.

Member Author
@ricardoV94 ricardoV94 Jan 9, 2026

Yeah that first one seems wrong.

Even if we fix it, I think our check for "sequence" in split_dims (or wherever the problem was) should be more like `isinstance(x, Sequence) or (isinstance(x, TensorVariable) and x.ndim == 1)`.

1d numpy arrays should also be valid, but maybe those pass the Sequence instance check.
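
Something along these lines (a hypothetical helper, not the actual split_dims code):

```python
from collections.abc import Sequence

import numpy as np
from pytensor.tensor.variable import TensorVariable


def is_sizes_sequence(x):
    """Return True if x should be treated as a sequence of split sizes."""
    return (
        isinstance(x, Sequence)
        # np.ndarray is not a collections.abc.Sequence, so check it explicitly
        or (isinstance(x, np.ndarray) and x.ndim == 1)
        or (isinstance(x, TensorVariable) and x.ndim == 1)
    )
```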

Member Author

Maybe we should remove the ndim argument altogether? numpy doesn't have it and I don't think we need it.

I thought it was just used for validation, but it seems to affect non-raising outcomes.

Member

I'm +1 for removing it. I never knew it existed, and it seems like it's overloading the function.

If I had to guess though, it's exactly for this situation. We have an argument whose type is `int | Variable | tuple[int | Variable]`. The Variable, though, can be either a scalar or an array, so really the typing is something like `int | Variable[ndim=0] | Variable[ndim=1] | tuple[int | Variable[ndim=0]]`. When we do the `if not isinstance(shape, tuple): shape = (shape,)` we're ignoring the `Variable[ndim=1]` case. Calling `as_tensor_variable(tuple[Variable[ndim=0]]) -> Variable[ndim=1]` makes sense to me, and matches the numpy behavior. In this case we're counting on the `ndim=1` argument to guard against the case of `as_tensor_variable(tuple[Variable[ndim=1]]) -> Variable[ndim=2]`.

Typing all this out, it seems like an abuse of the as_tensor_variable function.
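
Spelling this out as an annotation (illustrative only; `ShapeLike` is a hypothetical alias, Python 3.10+ union syntax, and the ndim distinction cannot actually be expressed, which is exactly the problem):

```python
from pytensor.tensor.variable import TensorVariable

# Conceptually:
#   int | TensorVariable[ndim=0] | TensorVariable[ndim=1] | tuple[int | TensorVariable[ndim=0], ...]
# but since ndim is not part of the static type, the best we can write today is:
ShapeLike = int | TensorVariable | tuple[int | TensorVariable, ...]
```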

Member Author
@ricardoV94 ricardoV94 Jan 9, 2026

Yeah agreed. Would be really nice to be able to have those `TensorVariable[ndim=0]` types btw. Need to nerdsnipe some type-hint lovers.

@jessegrabowski
Member

I reverted the changes to as_tensor_variable. At minimum it's out of scope for this PR. Implementing more careful checks of the shape argument (based on the analysis in the comment above) was sufficient to clear the test failures. We can revisit the ndims argument later.

Something else I noticed was that we're passing dtype to as_tensor_variable. This doesn't do anything in the Variable case, so I changed it to an explicit cast inside the Op's make_node (I left it in the wrapper to handle the Sequence case).
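
Roughly the change described (a sketch, not the exact diff), assuming the shape argument arrives as a symbolic vector:

```python
import pytensor.tensor as pt

shape = pt.vector("shape")  # stands in for a user-supplied symbolic shape

# dtype passed to as_tensor_variable is a no-op for an existing Variable,
# so the conversion and the cast are done as two explicit steps:
shape = pt.as_tensor_variable(shape)
shape = pt.cast(shape, "int64")
```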

@ricardoV94
Member Author

No, better not to cast variables in make_node, but to raise like before. That's what shape ops always do. If a user passes a float as a shape argument, it's likely a bug and this would mask it.
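
A sketch of the "raise instead of cast" check being asked for (hypothetical code, not the actual Op):

```python
import numpy as np
import pytensor.tensor as pt


def check_shape_dtype(shape):
    """Reject non-integer shape arguments instead of silently casting them."""
    shape = pt.as_tensor_variable(shape)
    if not np.issubdtype(np.dtype(shape.dtype), np.integer):
        raise TypeError(f"shape must have an integer dtype, got {shape.dtype}")
    return shape
```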

@jessegrabowski
Member

Someday I will merge a PR

```python
)

if not shape:
if empty_shape:
```

Member Author
@ricardoV94 ricardoV94 Jan 10, 2026

What about just `shape.type.shape == (0,)` for the variable case? Also, if you standardize as_tensor_variable, you don't need the variable vs non-variable case.
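
For illustration, the static shape check on a symbolic vector (assuming pytensor's `shape` keyword on the tensor constructors):

```python
import pytensor.tensor as pt

s = pt.vector("shape", dtype="int64", shape=(0,))  # statically known to be empty
print(s.type.shape == (0,))  # True

t = pt.vector("shape", dtype="int64")              # length unknown at graph build time
print(t.type.shape == (0,))  # False -- type.shape is (None,)
```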

Member Author

But also, do we need the special squeeze branch, or would the Op do the right thing anyway?

Member

Tests pass without it (as long as I adjust the existing test_split_size_zero_shape test to pass dtype int to the shape argument), so I guess not.

@ricardoV94
Member Author

I'm happy with the PR. I'll fix the git history and merge
